{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 测试我们的算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "iris = datasets.load_iris()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = iris.data\n",
    "y = iris.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(150, 4)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(150,)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 31,  61,  82, 113, 127,  30,  63, 124,  27,  10,  18,   0,   9,\n",
       "       107,  22,   6,  46, 131, 104, 148,  15,  81,  43,  47,   1,  54,\n",
       "        33, 117,  76,  52,  70,   7,  41, 139, 100, 125,  56, 116,  92,\n",
       "       149,  26,  24,  72, 130,  94, 119,  34,  45,  35,   3,  32, 118,\n",
       "       105,  17,  73,  90,  23,  86,  64, 109,  38, 137, 132,  55, 106,\n",
       "       142,  16,  68, 135,  91, 120,  51,   2,  21,  84, 102,  11,  93,\n",
       "        14, 121,  74,  49,  99,  62, 101,  71,  37,   5,  75, 146, 110,\n",
       "        87, 143,   8,  59,  28,  97,  83,  78,  69,  29,  88,  20, 123,\n",
       "        39, 126,  85,  98,  53, 128, 145, 108,  96,  19,  44,  60,  13,\n",
       "       147,  77,  40,  79,  65,  58, 141,  48,  89, 122,  95, 115, 133,\n",
       "       140, 136,  50,  36,  25,  57, 103,  12, 129,  80, 144,   4,  67,\n",
       "       138,  42, 112, 111, 114, 134,  66])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shuffle_indexes = np.random.permutation(len(X))\n",
    "shuffle_indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hold out 20% of the rows for testing; int() truncates a fractional row count.\n",
    "test_ratio = 0.2\n",
    "test_size = int(len(X) * test_ratio)\n",
    "test_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 31,  61,  82, 113, 127,  30,  63, 124,  27,  10,  18,   0,   9,\n",
       "       107,  22,   6,  46, 131, 104, 148,  15,  81,  43,  47,   1,  54,\n",
       "        33, 117,  76,  52])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_indexes = shuffle_indexes[:test_size]\n",
    "train_indexes = shuffle_indexes[test_size:]\n",
    "test_indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = X[train_indexes]\n",
    "y_train = y[train_indexes]\n",
    "\n",
    "X_test = X[test_indexes]\n",
    "y_test = y[test_indexes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(120, 4)\n",
      "(120,)\n"
     ]
    }
   ],
   "source": [
    "print(X_train.shape)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30, 4)\n",
      "(30,)\n"
     ]
    }
   ],
   "source": [
    "print(X_test.shape)\n",
    "print(y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def train_test_split(X, y, test_ratio=0.2, seed=None):\n",
    "    \"\"\"Randomly split X and y into X_train, X_test, y_train, y_test.\n",
    "\n",
    "    Row indices are shuffled with np.random.permutation. The first\n",
    "    int(len(X) * test_ratio) shuffled rows form the test set and the\n",
    "    remaining rows form the training set.\n",
    "\n",
    "    Parameters:\n",
    "        X: array of samples, indexed along axis 0.\n",
    "        y: array of labels with y.shape[0] == X.shape[0].\n",
    "        test_ratio: fraction of rows placed in the test set, in [0.0, 1.0].\n",
    "        seed: optional value passed to np.random.seed, which reseeds\n",
    "            NumPy's global generator. None leaves the generator untouched.\n",
    "\n",
    "    Returns:\n",
    "        The tuple (X_train, X_test, y_train, y_test).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if X and y differ in length or test_ratio is\n",
    "            outside [0.0, 1.0].\n",
    "    \"\"\"\n",
    "    assert X.shape[0] == y.shape[0], \"the size of X must be equal to the size of y\"\n",
    "    assert 0.0 <= test_ratio <= 1.0, \"test_ratio must be valid\"\n",
    "\n",
    "    # Compare against None so that seed=0 is honoured instead of ignored.\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "\n",
    "    shuffle_indexes = np.random.permutation(len(X))\n",
    "\n",
    "    # Size the split from the X argument, never from a notebook-level variable.\n",
    "    test_size = int(len(X) * test_ratio)\n",
    "    test_indexes = shuffle_indexes[:test_size]\n",
    "    train_indexes = shuffle_indexes[test_size:]\n",
    "\n",
    "    X_train = X[train_indexes]\n",
    "    y_train = y[train_indexes]\n",
    "\n",
    "    X_test = X[test_indexes]\n",
    "    y_test = y[test_indexes]\n",
    "\n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用我们的算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split with the function's defaults: test_ratio=0.2 and an unseeded shuffle.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(120, 4)\n",
      "(120,)\n"
     ]
    }
   ],
   "source": [
    "print(X_train.shape)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30, 4)\n",
      "(30,)\n"
     ]
    }
   ],
   "source": [
    "print(X_test.shape)\n",
    "print(y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from math import sqrt\n",
    "from collections import Counter\n",
    "\n",
    "class kNNClassifier:\n",
    "    \"\"\"k-nearest-neighbours classifier using Euclidean distance and a majority vote.\"\"\"\n",
    "\n",
    "    def __init__(self, k):\n",
    "        \"\"\"Create an unfitted classifier that votes among the k nearest neighbours.\"\"\"\n",
    "        assert k >= 1, \"K must be valid!\"\n",
    "        self.k = k\n",
    "        self._X_train = None\n",
    "        self._y_train = None\n",
    "\n",
    "    def fit(self, X_train, y_train):\n",
    "        \"\"\"Store the training samples X_train and labels y_train.\n",
    "\n",
    "        Returns self so that calls can be chained.\n",
    "        \"\"\"\n",
    "        assert X_train.shape[0] == y_train.shape[0], \"the size of X_train must equal to the size of y_train\"\n",
    "        assert self.k <= X_train.shape[0], \"the size of X_train must be at least k\"\n",
    "\n",
    "        self._X_train = X_train\n",
    "        self._y_train = y_train\n",
    "        return self\n",
    "\n",
    "    def predict(self, X_predict):\n",
    "        \"\"\"Return an array holding one predicted label per row of X_predict.\"\"\"\n",
    "        # Both stored arrays must be set; checking only _X_train would let a\n",
    "        # half-initialised instance through and fail later inside _predict.\n",
    "        assert self._X_train is not None and self._y_train is not None, \"must fit before predict\"\n",
    "        assert self._X_train.shape[1] == X_predict.shape[1], \"the feature number of X_predict must be equal to X_train\"\n",
    "\n",
    "        y_predict = [self._predict(x) for x in X_predict]\n",
    "        return np.array(y_predict)\n",
    "\n",
    "    def _predict(self, x):\n",
    "        \"\"\"Return the predicted label for the single sample x.\"\"\"\n",
    "        assert self._X_train.shape[1] == x.shape[0], \"the feature number of x must be equal to X_train\"\n",
    "\n",
    "        # Euclidean distance from x to every stored training sample.\n",
    "        distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in self._X_train]\n",
    "        nearest = np.argsort(distances)\n",
    "\n",
    "        # Labels of the k closest samples, then the most frequent label wins.\n",
    "        topK_y = [self._y_train[i] for i in nearest[:self.k]]\n",
    "        votes = Counter(topK_y)\n",
    "\n",
    "        # most_common(1) yields [(label, count)]; keep only the label.\n",
    "        return votes.most_common(1)[0][0]\n",
    "\n",
    "    def __repr__(self):\n",
    "        return \"KNN(k=%d)\" % self.k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_Knn_clf = kNNClassifier(k = 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNN(k=3)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_Knn_clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict = my_Knn_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 2, 1, 1, 1, 0, 1, 0, 0, 2, 0, 1, 0, 0, 0, 0, 2, 0, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 0])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 2, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 2, 0, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 0])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "29"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(y_predict == y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9666666666666667"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(y_predict == y_test) / len(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### sklearn中的train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses scikit-learn's train_test_split imported in the previous cell; random_state pins the shuffle.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(120, 4)\n",
      "(120,)\n"
     ]
    }
   ],
   "source": [
    "print(X_train.shape)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(30, 4)\n",
      "(30,)\n"
     ]
    }
   ],
   "source": [
    "print(X_test.shape)\n",
    "print(y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_Knn_clf = kNNClassifier(k = 3)\n",
    "my_Knn_clf.fit(X_train, y_train)\n",
    "y_predict = my_Knn_clf.predict(X_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
